################################################################################
# 06_main_regressions.R
# Main Regression Analyses for JHR Paper
# "Teacher Testing Standards and the New Teacher Pipeline"
# Law, Marks, and Stern
#
# Tables Generated:
#   - Table 4: Composite Enrollments (8 columns)
#   - Table 5: Teacher Preparation Program Graduates (8 columns)
#   - Table 6: Robustness Tests Using Alternative Measures of TDI (5 columns each)
#   - Table A1: Robustness Tests Using Alternative Samples (5 columns each)
#   - Table A2: Event Study Coefficients
#
# Note: Converts Stata reghdfe + pairs cluster bootstrap to
#       R fixest::feols + manual pairs cluster bootstrap
################################################################################

library(tidyverse)
library(readxl)
library(writexl)
library(haven)
library(fixest)
library(modelsummary)
library(broom)

# ── Configuration ─────────────────────────────────────────────────────────────

# Set to TRUE for publication-quality bootstrap SEs (slow: ~20-30 min total)
# Set to FALSE for fast iteration with standard clustered SEs
USE_BOOTSTRAP <- FALSE
BOOT_REPS <- 1000   # matches Stata: bootstrap _b, reps(1000)
BOOT_SEED <- 12345  # matches Stata: seed(12345)

# Output directory
out_dir <- "output/tables"
if (!dir.exists(out_dir)) dir.create(out_dir, recursive = TRUE)

# ── Load Raw Data ─────────────────────────────────────────────────────────────

cat("Loading analysis data...\n")

enrollment_raw <- read_excel("data/cleaned/enrollment_event_data.xlsx")
graduation_raw <- read_excel("data/cleaned/graduation_event_data.xlsx")

cat("  Raw enrollment:", nrow(enrollment_raw), "obs,",
    n_distinct(enrollment_raw$unitid), "universities\n")
cat("  Raw graduation:", nrow(graduation_raw), "obs,",
    n_distinct(graduation_raw$unitid), "universities\n")

# ── Define Constants ──────────────────────────────────────────────────────────

# Shrinking states: derived from pop_percentage_change column in data (< 0)
# Matches Stata: gen shrinking_state = 1 if pop_percentage_change < 0
# Note: DC has pop_percentage_change > 0 (growing), despite being in some hardcoded lists
SHRINKING_STATES <- c("CT", "LA", "ME", "MS", "NH", "NJ", "PA", "VT", "WI", "WV")

# States excluded from main sample
EXCLUDE_STATES <- c("ND", "TN")

# AR and DE test_index fill value for 2018 (carried forward from 2016)
ARDE_FILL_VALUE <- -0.4162782

# SC test_index value after 2017 (lowered scores June 2017, retroactive Sept 2016)
SC_POST2017_VALUE <- -0.6576

# ── Control Variable Definitions ──────────────────────────────────────────────

# State economic controls
state_demo <- c("real_income", "unemployment_rate")

# State education policy controls (Kraft et al. 2020)
state_kraft <- c("passevals", "implementevals", "eliminate_tenure",
                 "increase_probationary_period", "weaken_bargaining",
                 "eliminate_union_dues", "won_race_top", "common_core", "edtpa")

# University controls
uni_controls <- c("test_optional", "satpct2", "actpct2",
                  "satvr25_2", "satvr75_2", "satmt25_2", "satmt75_2",
                  "actcm25_2", "actcm75_2",
                  "pell_percent2", "pell_amount2",
                  "loan_percent2", "loan_average2", "enrollment_total2")

# Full control set
all_controls <- c(state_demo, state_kraft, uni_controls)

# ── Data Preparation ──────────────────────────────────────────────────────────

fix_test_index <- function(data) {
  # Fill forward test_index within each state to handle NAs
  # Known issues: AR/DE have NA in 2018, ME has NA in 2020, VA has NA in 2020
  data %>%
    group_by(State) %>%
    arrange(State, year) %>%
    fill(test_index, .direction = "down") %>%
    ungroup()
}

prepare_enrollment_data <- function(data, exclude_states = EXCLUDE_STATES) {
  data %>%
    # Remove excluded states first
    filter(!State %in% exclude_states) %>%
    # Fill forward test_index within each state
    fix_test_index() %>%
    # Handle CT after 2017: eliminated Praxis Core requirement Sept 2016
    mutate(test_index = if_else(State == "CT" & year > 2017, NA_real_, test_index)) %>%
    # Handle SC after 2017: lowered scores June 2017, retroactive to Sept 2016
    mutate(test_index = if_else(State == "SC" & year > 2017, SC_POST2017_VALUE, test_index)) %>%
    # Create factors for fixed effects
    mutate(
      unitid_f = factor(unitid),
      year_f = factor(year),
      state_f = factor(State),
      shrinking = as.integer(State %in% SHRINKING_STATES)
    ) %>%
    # Drop missing test_index (CT 2018 after adjustment)
    filter(!is.na(test_index))
}

# Build state-level test_index lag for graduation data
# Graduation starts at 2009; we need 2008 test_index from enrollment data for lagging
build_state_lag <- function(grad_data, enroll_data, exclude_states = EXCLUDE_STATES) {
  # Extract state-year test_index from graduation data
  grad_state <- grad_data %>%
    filter(!State %in% exclude_states) %>%
    select(State, year, test_index) %>%
    distinct()

  # Get 2008 test_index from enrollment data (for lagging to 2009)
  enroll_2008 <- enroll_data %>%
    filter(year == 2008, !State %in% exclude_states) %>%
    select(State, year, test_index) %>%
    distinct()

  # Combine and create lag at state level
  state_panel <- bind_rows(enroll_2008, grad_state) %>%
    distinct(State, year, .keep_all = TRUE) %>%
    arrange(State, year) %>%
    group_by(State) %>%
    fill(test_index, .direction = "down") %>%
    mutate(test_index_lag = lag(test_index, 1)) %>%
    ungroup()

  state_panel %>% select(State, year, test_index_lag)
}

prepare_graduation_data <- function(data, enroll_data,
                                     exclude_states = EXCLUDE_STATES) {
  # First fix test_index NAs within each state
  prepped <- data %>%
    filter(!State %in% exclude_states) %>%
    fix_test_index() %>%
    # Handle CT after 2017
    mutate(test_index = if_else(State == "CT" & year > 2017, NA_real_, test_index)) %>%
    # Handle SC after 2017
    mutate(test_index = if_else(State == "SC" & year > 2017, SC_POST2017_VALUE, test_index))

  # Build state-level lag (using 2008 enrollment data for the 2009 lag)
  state_lag <- build_state_lag(prepped, enroll_data, exclude_states)

  prepped %>%
    # Merge state-level lag
    left_join(state_lag, by = c("State", "year")) %>%
    # Create factors for fixed effects
    mutate(
      unitid_f = factor(unitid),
      year_f = factor(year),
      state_f = factor(State),
      shrinking = as.integer(State %in% SHRINKING_STATES)
    ) %>%
    # Ensure l_ctotalt exists
    mutate(l_ctotalt = log(ctotalt + 1))
    # NOTE: Do NOT filter on !is.na(test_index) here.
    # Event study uses year_XXXX vars (not test_index) so needs all rows.
    # TWFE sample filters on !is.na(test_index_lag) later.
}

# Prepare main analysis datasets
enrollment_data <- prepare_enrollment_data(enrollment_raw)
graduation_data <- prepare_graduation_data(graduation_raw, enrollment_raw)

cat("  Enrollment after prep:", nrow(enrollment_data), "obs,",
    n_distinct(enrollment_data$unitid), "universities,",
    n_distinct(enrollment_data$State), "states\n")
cat("  Graduation after prep:", nrow(graduation_data), "obs,",
    n_distinct(graduation_data$unitid), "universities,",
    n_distinct(graduation_data$State), "states\n")

# ── Construct Selectivity for Enrollment Data ─────────────────────────────────
# More selective = SAT verbal 25th pctl OR ACT composite 25th pctl above
# in-sample median as of 2010. Below median or missing = Less Selective.

if (!"selective" %in% names(enrollment_data)) {
  # Try to get selectivity from graduation data (already has it pre-computed)
  if ("selective" %in% names(graduation_data)) {
    sel_map <- graduation_data %>%
      select(unitid, selective) %>%
      distinct() %>%
      group_by(unitid) %>%
      summarize(selective = max(selective, na.rm = TRUE), .groups = "drop")

    enrollment_data <- enrollment_data %>%
      left_join(sel_map, by = "unitid") %>%
      mutate(selective = replace_na(selective, 0L))

    cat("  Selectivity merged from graduation data\n")
  } else {
    # Construct from SAT/ACT scores at 2010
    scores_2010 <- enrollment_data %>%
      filter(year == 2010) %>%
      select(unitid, satvr25, actcm25) %>%
      distinct()

    med_sat <- median(scores_2010$satvr25, na.rm = TRUE)
    med_act <- median(scores_2010$actcm25, na.rm = TRUE)

    scores_2010 <- scores_2010 %>%
      mutate(selective = as.integer(
        (!is.na(satvr25) & satvr25 > med_sat) |
        (!is.na(actcm25) & actcm25 > med_act)
      ))

    enrollment_data <- enrollment_data %>%
      left_join(scores_2010 %>% select(unitid, selective), by = "unitid") %>%
      mutate(selective = replace_na(selective, 0L))

    cat("  Selectivity constructed from 2010 SAT/ACT scores\n")
    cat("    SAT verbal 25th median:", med_sat, ", ACT composite 25th median:", med_act, "\n")
  }
}

cat("  Enrollment selectivity: More=", sum(enrollment_data$selective == 1) / 6,
    " Less=", sum(enrollment_data$selective == 0) / 6, " (approx per year)\n")
cat("  Graduation selectivity: More=", n_distinct(graduation_data$unitid[graduation_data$selective == 1]),
    " Less=", n_distinct(graduation_data$unitid[graduation_data$selective == 0]), "\n")

# ── Helper: Run TWFE Regression with Optional Bootstrap ───────────────────────
# Bootstrap method: pairs cluster bootstrap (resample states with replacement)
# Matches Stata: bootstrap _b, reps(1000) seed(12345) cluster(state) idcluster(statefips)

run_twfe <- function(df, outcome, treatment_var = "test_index",
                     controls = NULL, fe = "unitid_f + year_f",
                     unit_var = "unitid",
                     bootstrap = USE_BOOTSTRAP,
                     B = BOOT_REPS, seed = BOOT_SEED) {

  # Filter to controls that exist in data and are not entirely NA
  valid_controls <- if (!is.null(controls)) controls[controls %in% names(df)] else NULL
  if (!is.null(valid_controls)) {
    valid_controls <- valid_controls[sapply(valid_controls, function(v) !all(is.na(df[[v]])))]
    if (length(valid_controls) == 0) valid_controls <- NULL
  }

  # Build formula
  if (is.null(valid_controls) || length(valid_controls) == 0) {
    fml <- as.formula(paste0(outcome, " ~ ", treatment_var, " | ", fe))
  } else {
    control_str <- paste(valid_controls, collapse = " + ")
    fml <- as.formula(paste0(outcome, " ~ ", treatment_var, " + ",
                              control_str, " | ", fe))
  }

  # Run main regression with standard clustered SEs
  mod <- feols(fml, data = df, cluster = ~State)

  # Extract coefficient and SE
  coef_val <- coef(mod)[treatment_var]
  se_val <- sqrt(vcov(mod)[treatment_var, treatment_var])
  p_val <- 2 * pt(-abs(coef_val / se_val), df = n_distinct(df$State) - 1)

  # Pairs cluster bootstrap (resample entire states with replacement)
  if (bootstrap) {
    set.seed(seed)
    clusters <- unique(df$State)
    G <- length(clusters)
    cluster_list <- split(df, df$State)

    boot_coefs <- numeric(B)
    n_failures <- 0

    for (b in 1:B) {
      sampled <- sample(clusters, G, replace = TRUE)
      boot_data <- do.call(rbind, cluster_list[as.character(sampled)])
      rownames(boot_data) <- NULL

      tryCatch({
        boot_mod <- feols(fml, data = boot_data)
        boot_coefs[b] <- coef(boot_mod)[treatment_var]
      }, error = function(e) {
        boot_coefs[b] <<- NA
        n_failures <<- n_failures + 1
      })
    }

    valid_coefs <- boot_coefs[!is.na(boot_coefs)]
    if (length(valid_coefs) > B * 0.5) {
      se_val <- sd(valid_coefs)
      p_val <- 2 * pnorm(-abs(coef_val / se_val))
      if (n_failures > 0) {
        cat(sprintf("    Note: %d/%d bootstrap reps failed\n", n_failures, B))
      }
    } else {
      warning("Bootstrap failed (>50% reps failed) — using standard clustered SEs")
    }
  }

  # Compute mean of dependent variable in levels
  if (grepl("^l_", outcome)) {
    level_var <- sub("^l_", "", outcome)
    if (level_var %in% names(df)) {
      mean_y <- round(mean(df[[level_var]], na.rm = TRUE), 0)
    } else {
      mean_y <- round(mean(exp(df[[outcome]]) - 1, na.rm = TRUE), 0)
    }
  } else {
    mean_y <- round(mean(df[[outcome]], na.rm = TRUE), 0)
  }

  list(
    model = mod,
    coef = coef_val,
    se = se_val,
    p = p_val,
    n = mod$nobs,
    n_univs = n_distinct(df[[unit_var]]),
    n_clusters = n_distinct(df$State),
    mean_y = mean_y
  )
}

# Helper to format significance stars
stars <- function(p) {
  if (is.na(p)) return("")
  if (p < 0.01) return("**")
  if (p < 0.05) return("*")
  if (p < 0.10) return("+")
  return("")
}

# ──────────────────────────────────────────────────────────────────────────────
# TABLE 4: Composite Enrollments (Total Enrollments)
# DV = log(fall education enrollments + 1) = l_eftotlt
# Treatment = test_index (contemporaneous)
# ──────────────────────────────────────────────────────────────────────────────

cat("\n", strrep("=", 60), "\n")
cat("TABLE 4: Composite Enrollments\n")
cat(strrep("=", 60), "\n")

# Col (1): All, no controls
cat("  Running col (1): All, no controls...\n")
t4c1 <- run_twfe(enrollment_data, "l_eftotlt", "test_index", controls = NULL)

# Col (2): All, full controls
cat("  Running col (2): All, full controls...\n")
t4c2 <- run_twfe(enrollment_data, "l_eftotlt", "test_index", controls = all_controls)

# Col (3): Less Selective, full controls
cat("  Running col (3): Less Selective...\n")
t4c3 <- run_twfe(enrollment_data %>% filter(selective == 0),
                  "l_eftotlt", "test_index", controls = all_controls)

# Col (4): More Selective, full controls
cat("  Running col (4): More Selective...\n")
t4c4 <- run_twfe(enrollment_data %>% filter(selective == 1),
                  "l_eftotlt", "test_index", controls = all_controls)

# Col (5): White enrollments, full sample (paper shows N=2,896 = full sample)
cat("  Running col (5): White...\n")
t4c5 <- run_twfe(enrollment_data %>% filter(!is.na(l_efwhitt)),
                  "l_efwhitt", "test_index", controls = all_controls)

# Col (6): Non-White enrollments, full sample
cat("  Running col (6): Non-White...\n")
# Use l_nonwhite (the actual column name in enrollment data)
nw_enroll_var <- if ("l_nonwhite" %in% names(enrollment_data)) "l_nonwhite" else "l_efnwhit"
t4c6 <- run_twfe(enrollment_data %>% filter(!is.na(.data[[nw_enroll_var]])),
                  nw_enroll_var, "test_index", controls = all_controls)

# Col (7): Shrinking State, full controls
cat("  Running col (7): Shrinking State...\n")
t4c7 <- run_twfe(enrollment_data %>% filter(shrinking == 1),
                  "l_eftotlt", "test_index", controls = all_controls)

# Col (8): Growing State, full controls
cat("  Running col (8): Growing State...\n")
t4c8 <- run_twfe(enrollment_data %>% filter(shrinking == 0),
                  "l_eftotlt", "test_index", controls = all_controls)

# Compile Table 4 results
table4 <- tibble(
  Column = paste0("(", 1:8, ")"),
  Sample = c("All", "All", "Less Selective", "More Selective",
             "White", "Non-White", "Shrinking", "Growing"),
  TDI = c(t4c1$coef, t4c2$coef, t4c3$coef, t4c4$coef,
          t4c5$coef, t4c6$coef, t4c7$coef, t4c8$coef),
  SE = c(t4c1$se, t4c2$se, t4c3$se, t4c4$se,
         t4c5$se, t4c6$se, t4c7$se, t4c8$se),
  Stars = sapply(c(t4c1$p, t4c2$p, t4c3$p, t4c4$p,
                   t4c5$p, t4c6$p, t4c7$p, t4c8$p), stars),
  N = c(t4c1$n, t4c2$n, t4c3$n, t4c4$n,
        t4c5$n, t4c6$n, t4c7$n, t4c8$n),
  Mean_Y = c(t4c1$mean_y, t4c2$mean_y, t4c3$mean_y, t4c4$mean_y,
             t4c5$mean_y, t4c6$mean_y, t4c7$mean_y, t4c8$mean_y),
  N_Univs = c(t4c1$n_univs, t4c2$n_univs, t4c3$n_univs, t4c4$n_univs,
              t4c5$n_univs, t4c6$n_univs, t4c7$n_univs, t4c8$n_univs),
  Controls = c("No", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes"),
  # Paper targets for comparison
  Paper_TDI = c(-0.16, -0.16, -0.27, -0.07, -0.31, -0.10, -0.16, -0.12),
  Paper_SE = c(0.10, 0.08, 0.20, 0.09, 0.23, 0.14, 0.10, 0.37),
  Paper_N = c(2896, 2896, 1313, 1583, 2896, 2896, 1567, 1329)
)

write_xlsx(table4, file.path(out_dir, "table_4_enrollments.xlsx"))

cat("\nTable 4 Results vs Paper:\n")
cat(sprintf("  %-5s %-16s  %7s %6s %6s  | %7s %6s %6s\n",
            "Col", "Sample", "TDI", "SE", "N", "Paper", "SE", "N"))
cat("  ", strrep("-", 70), "\n")
for (i in 1:8) {
  cat(sprintf("  %-5s %-16s  %7.3f %6.3f %6d  | %7.2f %6.2f %6d\n",
              table4$Column[i], table4$Sample[i],
              table4$TDI[i], table4$SE[i], table4$N[i],
              table4$Paper_TDI[i], table4$Paper_SE[i], table4$Paper_N[i]))
}

# ──────────────────────────────────────────────────────────────────────────────
# TABLE 5: Teacher Preparation Program Graduates
# DV = log(teacher prep completions + 1) = l_ctotalt
# Treatment = test_index_lag (1-year lag)
# ──────────────────────────────────────────────────────────────────────────────

cat("\n", strrep("=", 60), "\n")
cat("TABLE 5: Teacher Preparation Graduates\n")
cat(strrep("=", 60), "\n")

# For TWFE regressions, need non-NA test_index AND test_index_lag
grad_twfe <- graduation_data %>%
  filter(!is.na(test_index), !is.na(test_index_lag))

cat("  Graduation total:", nrow(graduation_data), "obs\n")
cat("  Graduation TWFE sample (non-NA lag):", nrow(grad_twfe), "obs\n")

# Col (1): All, no controls
cat("  Running col (1): All, no controls...\n")
t5c1 <- run_twfe(grad_twfe, "l_ctotalt", "test_index_lag", controls = NULL)

# Col (2): All, full controls
cat("  Running col (2): All, full controls...\n")
t5c2 <- run_twfe(grad_twfe, "l_ctotalt", "test_index_lag", controls = all_controls)

# Col (3): Less Selective, full controls
cat("  Running col (3): Less Selective...\n")
t5c3 <- run_twfe(grad_twfe %>% filter(selective == 0),
                  "l_ctotalt", "test_index_lag", controls = all_controls)

# Col (4): More Selective, full controls
cat("  Running col (4): More Selective...\n")
t5c4 <- run_twfe(grad_twfe %>% filter(selective == 1),
                  "l_ctotalt", "test_index_lag", controls = all_controls)

# Col (5): White graduates, 2012 onward
cat("  Running col (5): White (2012+)...\n")
grad_white_2012 <- grad_twfe %>%
  filter(year >= 2012, !is.na(l_cwhitt))
t5c5 <- run_twfe(grad_white_2012, "l_cwhitt", "test_index_lag", controls = all_controls)

# Col (6): Non-White graduates, 2012 onward
cat("  Running col (6): Non-White (2012+)...\n")
nw_grad_var <- if ("l_cnonwhite" %in% names(grad_twfe)) "l_cnonwhite" else "l_cnwhit"
grad_nw_2012 <- grad_twfe %>%
  filter(year >= 2012, !is.na(.data[[nw_grad_var]]))
t5c6 <- run_twfe(grad_nw_2012, nw_grad_var, "test_index_lag", controls = all_controls)

# Col (7): Shrinking State, full controls
cat("  Running col (7): Shrinking State...\n")
t5c7 <- run_twfe(grad_twfe %>% filter(shrinking == 1),
                  "l_ctotalt", "test_index_lag", controls = all_controls)

# Col (8): Growing State, full controls
cat("  Running col (8): Growing State...\n")
t5c8 <- run_twfe(grad_twfe %>% filter(shrinking == 0),
                  "l_ctotalt", "test_index_lag", controls = all_controls)

# Compile Table 5 results
table5 <- tibble(
  Column = paste0("(", 1:8, ")"),
  Sample = c("All", "All", "Less Selective", "More Selective",
             "White", "Non-White", "Shrinking", "Growing"),
  TDI = c(t5c1$coef, t5c2$coef, t5c3$coef, t5c4$coef,
          t5c5$coef, t5c6$coef, t5c7$coef, t5c8$coef),
  SE = c(t5c1$se, t5c2$se, t5c3$se, t5c4$se,
         t5c5$se, t5c6$se, t5c7$se, t5c8$se),
  Stars = sapply(c(t5c1$p, t5c2$p, t5c3$p, t5c4$p,
                   t5c5$p, t5c6$p, t5c7$p, t5c8$p), stars),
  N = c(t5c1$n, t5c2$n, t5c3$n, t5c4$n,
        t5c5$n, t5c6$n, t5c7$n, t5c8$n),
  Mean_Y = c(t5c1$mean_y, t5c2$mean_y, t5c3$mean_y, t5c4$mean_y,
             t5c5$mean_y, t5c6$mean_y, t5c7$mean_y, t5c8$mean_y),
  N_Univs = c(t5c1$n_univs, t5c2$n_univs, t5c3$n_univs, t5c4$n_univs,
              t5c5$n_univs, t5c6$n_univs, t5c7$n_univs, t5c8$n_univs),
  Controls = c("No", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes"),
  Paper_TDI = c(-0.20, -0.22, -0.28, -0.20, -0.24, -0.19, -0.22, -0.18),
  Paper_SE = c(0.11, 0.10, 0.20, 0.07, 0.05, 0.08, 0.06, 0.12),
  Paper_N = c(5748, 5748, 2617, 3131, 4325, 4309, 3098, 2650)
)

write_xlsx(table5, file.path(out_dir, "table_5_graduations.xlsx"))

cat("\nTable 5 Results vs Paper:\n")
cat(sprintf("  %-5s %-16s  %7s %6s %6s  | %7s %6s %6s\n",
            "Col", "Sample", "TDI", "SE", "N", "Paper", "SE", "N"))
cat("  ", strrep("-", 70), "\n")
for (i in 1:8) {
  cat(sprintf("  %-5s %-16s  %7.3f %6.3f %6d  | %7.2f %6.2f %6d\n",
              table5$Column[i], table5$Sample[i],
              table5$TDI[i], table5$SE[i], table5$N[i],
              table5$Paper_TDI[i], table5$Paper_SE[i], table5$Paper_N[i]))
}

# ──────────────────────────────────────────────────────────────────────────────
# TABLE A1: Robustness Tests Using Alternative Samples
# Panel A = Enrollments (l_eftotlt, test_index), Panel B = Graduations (l_ctotalt, test_index_lag)
# All columns use full controls
# ──────────────────────────────────────────────────────────────────────────────

cat("\n", strrep("=", 60), "\n")
cat("TABLE A1: Robustness Tests\n")
cat(strrep("=", 60), "\n")

# Col (1): Main sample (verification — should match Table 4/5 col 2)
cat("  Running col (1): Main sample...\n")
a1_e1 <- t4c2  # reuse Table 4 col 2
a1_g1 <- t5c2  # reuse Table 5 col 2

# Col (2): Include ND
cat("  Running col (2): Include ND...\n")
enroll_inclND <- prepare_enrollment_data(enrollment_raw, exclude_states = c("TN"))
a1_e2 <- run_twfe(enroll_inclND, "l_eftotlt", "test_index", controls = all_controls)

grad_inclND <- prepare_graduation_data(graduation_raw, enrollment_raw, exclude_states = c("TN"))
grad_inclND_twfe <- grad_inclND %>% filter(!is.na(test_index), !is.na(test_index_lag))
a1_g2 <- run_twfe(grad_inclND_twfe, "l_ctotalt", "test_index_lag", controls = all_controls)

# Col (3): Include TN
cat("  Running col (3): Include TN...\n")
enroll_inclTN <- prepare_enrollment_data(enrollment_raw, exclude_states = c("ND"))
a1_e3 <- run_twfe(enroll_inclTN, "l_eftotlt", "test_index", controls = all_controls)

grad_inclTN <- prepare_graduation_data(graduation_raw, enrollment_raw, exclude_states = c("ND"))
grad_inclTN_twfe <- grad_inclTN %>% filter(!is.na(test_index), !is.na(test_index_lag))
a1_g3 <- run_twfe(grad_inclTN_twfe, "l_ctotalt", "test_index_lag", controls = all_controls)

# Col (4): Exclude LA (TDI decreased for LA)
cat("  Running col (4): Exclude LA...\n")
a1_e4 <- run_twfe(enrollment_data %>% filter(State != "LA"),
                    "l_eftotlt", "test_index", controls = all_controls)
a1_g4 <- run_twfe(grad_twfe %>% filter(State != "LA"),
                    "l_ctotalt", "test_index_lag", controls = all_controls)

# Col (5): Drop ME, SC, PA in later years
# Stata robustness do file DROPS observations for these states:
#   Enrollments: ME, PA, SC where year >= 2018 (biennial data, last year affected)
#   Graduations: ME, PA, SC where year >= 2020 (annual data, last year affected)
cat("  Running col (5): Drop ME, SC, PA (late years)...\n")
fix_states <- c("ME", "SC", "PA")

enroll_drop_mespa <- enrollment_data %>%
  filter(!(State %in% fix_states & year >= 2018))
a1_e5 <- run_twfe(enroll_drop_mespa, "l_eftotlt", "test_index", controls = all_controls)

grad_drop_mespa <- grad_twfe %>%
  filter(!(State %in% fix_states & year >= 2020))
a1_g5 <- run_twfe(grad_drop_mespa, "l_ctotalt", "test_index_lag", controls = all_controls)

# Compile Table A1
table_a1 <- tibble(
  Column = rep(paste0("(", 1:5, ")"), 2),
  Panel = c(rep("A: Enrollments", 5), rep("B: Graduations", 5)),
  Sample = rep(c("Main", "Include ND", "Include TN", "Exclude LA", "Fix ME,SC,PA"), 2),
  TDI = c(a1_e1$coef, a1_e2$coef, a1_e3$coef, a1_e4$coef, a1_e5$coef,
          a1_g1$coef, a1_g2$coef, a1_g3$coef, a1_g4$coef, a1_g5$coef),
  SE = c(a1_e1$se, a1_e2$se, a1_e3$se, a1_e4$se, a1_e5$se,
         a1_g1$se, a1_g2$se, a1_g3$se, a1_g4$se, a1_g5$se),
  N = c(a1_e1$n, a1_e2$n, a1_e3$n, a1_e4$n, a1_e5$n,
        a1_g1$n, a1_g2$n, a1_g3$n, a1_g4$n, a1_g5$n),
  Mean_Y = c(a1_e1$mean_y, a1_e2$mean_y, a1_e3$mean_y, a1_e4$mean_y, a1_e5$mean_y,
             a1_g1$mean_y, a1_g2$mean_y, a1_g3$mean_y, a1_g4$mean_y, a1_g5$mean_y),
  N_Univs = c(a1_e1$n_univs, a1_e2$n_univs, a1_e3$n_univs, a1_e4$n_univs, a1_e5$n_univs,
              a1_g1$n_univs, a1_g2$n_univs, a1_g3$n_univs, a1_g4$n_univs, a1_g5$n_univs),
  Paper_TDI = c(-0.16, -0.14, -0.14, -0.15, -0.15,
                -0.22, -0.21, -0.22, -0.24, -0.22),
  Paper_SE = c(0.08, 0.08, 0.08, 0.10, 0.07,
               0.10, 0.11, 0.09, 0.08, 0.10),
  Paper_N = c(2896, 2967, 3127, 2773, 2751,
              5748, 5892, 6203, 5504, 5619)
)

write_xlsx(table_a1, file.path(out_dir, "table_A1_robustness.xlsx"))

cat("\nTable A1 Results vs Paper:\n")
for (panel in c("A: Enrollments", "B: Graduations")) {
  cat("  Panel ", panel, ":\n")
  rows <- table_a1 %>% filter(Panel == panel)
  cat(sprintf("    %-5s %-14s  %7s %6s %6s  | %7s %6s %6s\n",
              "Col", "Sample", "TDI", "SE", "N", "Paper", "SE", "N"))
  cat("    ", strrep("-", 65), "\n")
  for (i in 1:nrow(rows)) {
    cat(sprintf("    %-5s %-14s  %7.3f %6.3f %6d  | %7.2f %6.2f %6d\n",
                rows$Column[i], rows$Sample[i],
                rows$TDI[i], rows$SE[i], rows$N[i],
                rows$Paper_TDI[i], rows$Paper_SE[i], rows$Paper_N[i]))
  }
}

# ──────────────────────────────────────────────────────────────────────────────
# TABLE A2 / EVENT STUDY REGRESSIONS (Figures 3 & 4)
# Uses pre-computed year_XXXX columns = DeltaTDI * I(year=t)
# Full controls, clustered at state level
# ──────────────────────────────────────────────────────────────────────────────

cat("\n", strrep("=", 60), "\n")
cat("TABLE A2 / EVENT STUDY REGRESSIONS\n")
cat(strrep("=", 60), "\n")

run_event_study <- function(df, outcome, controls = NULL, years = NULL,
                            bootstrap = USE_BOOTSTRAP, B = BOOT_REPS, seed = BOOT_SEED) {

  # Build formula using pre-computed year_XXXX interaction columns
  year_vars <- paste0("year_", years)
  year_vars <- year_vars[year_vars %in% names(df)]

  valid_controls <- if (!is.null(controls)) controls[controls %in% names(df)] else NULL
  if (!is.null(valid_controls)) {
    valid_controls <- valid_controls[sapply(valid_controls, function(v) !all(is.na(df[[v]])))]
    if (length(valid_controls) == 0) valid_controls <- NULL
  }

  if (is.null(valid_controls) || length(valid_controls) == 0) {
    fml <- as.formula(paste0(outcome, " ~ ", paste(year_vars, collapse = " + "),
                              " | unitid_f + year_f"))
  } else {
    control_str <- paste(valid_controls, collapse = " + ")
    fml <- as.formula(paste0(outcome, " ~ ", paste(year_vars, collapse = " + "),
                              " + ", control_str, " | unitid_f + year_f"))
  }

  # Run main regression with standard clustered SEs
  mod <- feols(fml, data = df, cluster = ~State)

  # Extract coefficients with standard clustered SEs
  results <- tidy(mod, conf.int = TRUE) %>%
    filter(grepl("^year_", term)) %>%
    mutate(year = as.numeric(gsub("year_", "", term))) %>%
    select(year, estimate, std.error, conf.low, conf.high)

  # Pairs cluster bootstrap for all year coefficients simultaneously
  if (bootstrap) {
    cat("    Running pairs cluster bootstrap (B=", B, ")...\n")
    set.seed(seed)
    clusters <- unique(df$State)
    G <- length(clusters)
    cluster_list <- split(df, df$State)

    boot_matrix <- matrix(NA, nrow = B, ncol = length(year_vars))
    colnames(boot_matrix) <- year_vars
    n_failures <- 0

    for (b in 1:B) {
      sampled <- sample(clusters, G, replace = TRUE)
      boot_data <- do.call(rbind, cluster_list[as.character(sampled)])
      rownames(boot_data) <- NULL

      tryCatch({
        boot_mod <- feols(fml, data = boot_data)
        bc <- coef(boot_mod)
        for (yv in year_vars) {
          if (yv %in% names(bc)) boot_matrix[b, yv] <- bc[yv]
        }
      }, error = function(e) {
        n_failures <<- n_failures + 1
      })
    }

    if (n_failures > 0) {
      cat(sprintf("    Note: %d/%d bootstrap reps failed\n", n_failures, B))
    }

    # Update SEs from bootstrap distribution
    for (i in seq_len(nrow(results))) {
      yv <- paste0("year_", results$year[i])
      valid_coefs <- boot_matrix[, yv][!is.na(boot_matrix[, yv])]
      if (length(valid_coefs) > B * 0.5) {
        boot_se <- sd(valid_coefs)
        results$std.error[i] <- boot_se
        z_crit <- qnorm(0.975)
        results$conf.low[i] <- results$estimate[i] - z_crit * boot_se
        results$conf.high[i] <- results$estimate[i] + z_crit * boot_se
      }
    }
  }

  list(model = mod, results = results, n = mod$nobs)
}

# Enrollment event study (biennial, ref=2012)
cat("  Running enrollment event study...\n")
enroll_es <- run_event_study(
  enrollment_data, "l_eftotlt",
  controls = all_controls,
  years = c(2008, 2010, 2014, 2016, 2018)
)

enroll_es_results <- enroll_es$results %>%
  add_row(year = 2012, estimate = 0, std.error = 0, conf.low = 0, conf.high = 0) %>%
  arrange(year)

write_csv(enroll_es_results, file.path(out_dir, "composite_enrollments_event_study_total.csv"))

# Graduation event study (annual, ref=2012)
# Note: event study uses test_index (not lagged) so uses full graduation_data
cat("  Running graduation event study...\n")
grad_es <- run_event_study(
  graduation_data, "l_ctotalt",
  controls = all_controls,
  years = c(2009, 2010, 2011, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020)
)

grad_es_results <- grad_es$results %>%
  add_row(year = 2012, estimate = 0, std.error = 0, conf.low = 0, conf.high = 0) %>%
  arrange(year)

write_csv(grad_es_results, file.path(out_dir, "composite_graduations_event_study_total.csv"))

# Table A2 formatted output: estimated event-study coefficients for both
# outcomes next to the values reported in the paper. Rows for 2012 are
# excluded from the table.
table_a2_enroll <- enroll_es_results %>%
  filter(year != 2012) %>%
  mutate(Panel = "Enrollments")

table_a2_grad <- grad_es_results %>%
  filter(year != 2012) %>%
  mutate(Panel = "Graduations")

# Paper targets for Table A2
paper_a2 <- tribble(
  ~year, ~Panel, ~Paper_Est, ~Paper_SE,
  2008, "Enrollments", -0.08, 0.16,
  2010, "Enrollments", -0.07, 0.13,
  2014, "Enrollments", -0.19, 0.07,
  2016, "Enrollments", -0.22, 0.08,
  2018, "Enrollments", -0.17, 0.13,
  2009, "Graduations", -0.06, 0.16,
  2010, "Graduations", -0.10, 0.15,
  2011, "Graduations", -0.07, 0.09,
  2013, "Graduations", -0.06, 0.13,
  2014, "Graduations", -0.06, 0.15,
  2015, "Graduations", -0.23, 0.16,
  2016, "Graduations", -0.26, 0.12,
  2017, "Graduations", -0.29, 0.13,
  2018, "Graduations", -0.28, 0.09,
  2019, "Graduations", -0.39, 0.09,
  2020, "Graduations", -0.37, 0.14
)

# Left join: estimated years without a paper target keep NA in
# Paper_Est / Paper_SE.
table_a2 <- bind_rows(table_a2_enroll, table_a2_grad) %>%
  left_join(paper_a2, by = c("year", "Panel"))

write_xlsx(table_a2, file.path(out_dir, "table_A2_event_study.xlsx"))

# Print every row of one Table A2 panel as "year: estimate (SE) | paper (SE)".
# sprintf() is vectorised over the rows, so the panel is filtered once and an
# empty panel simply prints nothing (a 1:nrow() loop would iterate over 1, 0).
print_a2_panel <- function(panel) {
  rows <- filter(table_a2, Panel == .env$panel)
  cat(sprintf("    %4d:  %6.3f (%5.3f)  |  %6.2f (%4.2f)\n",
              rows$year, rows$estimate, rows$std.error,
              rows$Paper_Est, rows$Paper_SE),
      sep = "")
  invisible(rows)
}

cat("\nTable A2 Event Study Results vs Paper:\n")
cat("  Enrollment event study (N=", enroll_es$n, ", paper=2,896):\n")
print_a2_panel("Enrollments")

# NOTE(review): the paper N printed below (5,784) differs from the graduation
# N used elsewhere in this script (5,748) - confirm which value is correct.
cat("  Graduation event study (N=", grad_es$n, ", paper=5,784):\n")
print_a2_panel("Graduations")

# ──────────────────────────────────────────────────────────────────────────────
# TABLE 6: Robustness Tests Using Alternative Measures of TDI
# Panel A = Enrollments, Panel B = Graduations
# Cols: (1) Math, (2) Reading, (3) Writing, (4) Binding Test, (5) ETS Min Scores
# Controls, university + year FE and state clustering follow Tables 4-5
#
# Data source: pre-merged subject-specific xlsx files read from t6_data_dir
#   (originally "JHR R and R/data/composite treatment/", matching Stata
#   test_by_test_regressions.do)
# Treatment variable: z_score_composite (cols 1-4), test_index_noncomposite (col 5)
# ──────────────────────────────────────────────────────────────────────────────

# Section banner for the console log, written in a single call.
cat(
  paste0("\n ", strrep("=", 60), " \n"),
  "TABLE 6: Alternative TDI Measures\n",
  paste0(strrep("=", 60), " \n"),
  sep = ""
)

# Build the enrollment regression sample for one Table 6 column from a
# pre-merged workbook (mirrors test_by_test_regressions.do).
#
# Args:
#   xlsx_path: path of the workbook passed to read_excel().
#   treat_var: name of the treatment column (default "z_score_composite").
#
# Returns: the workbook rows whose State is not in EXCLUDE_STATES, sorted by
#   unitid and year, with the treatment filled forward within unitid, set to
#   NA for CT after 2017, rows with a missing treatment removed, and the
#   factor columns unitid_f / year_f / state_f added.
prepare_t6_enrollment <- function(xlsx_path, treat_var = "z_score_composite") {
  panel <- read_excel(xlsx_path)

  # Drop the excluded states (ND and TN) and order each university's rows in
  # time so the forward fill runs from earlier to later years.
  panel <- arrange(filter(panel, !State %in% EXCLUDE_STATES), unitid, year)

  # Carry the last observed treatment value forward within each university.
  panel <- ungroup(
    fill(group_by(panel, unitid), !!sym(treat_var), .direction = "down")
  )

  panel <- mutate(
    panel,
    # Keep the treatment unless the row is CT after 2017, which is blanked
    # (Stata: replace z_score_composite = . if State=="CT" & year>2017).
    # SC is deliberately left unadjusted for the subject-specific regressions.
    !!treat_var := if_else(State != "CT" | year <= 2017,
                           .data[[treat_var]], NA_real_),
    # Factor versions of the identifiers for the fixed effects
    unitid_f = factor(unitid),
    year_f = factor(year),
    state_f = factor(State)
  )

  # Only rows with an observed treatment enter the regression sample.
  filter(panel, !is.na(.data[[treat_var]]))
}

# Build the graduation regression sample for one Table 6 column from a
# pre-merged workbook (mirrors test_by_test_regressions.do).
#
# Args:
#   xlsx_path: path of the workbook passed to read_excel().
#   treat_var: name of the treatment column (default "z_score_composite").
#
# Returns: the workbook rows whose State is not in EXCLUDE_STATES, sorted by
#   unitid and year, with nonwhite / l_nonwhite / l_cwhitt / l_ctotalt
#   recomputed, the treatment filled forward within unitid and set to NA for
#   CT after 2017, a "<treat_var>_lag" column added, rows with a missing lag
#   removed, and the factor columns unitid_f / year_f / state_f added.
prepare_t6_graduation <- function(xlsx_path, treat_var = "z_score_composite") {
  lagged_name <- paste0(treat_var, "_lag")

  panel <- read_excel(xlsx_path) %>%
    # Drop ND and TN
    filter(!State %in% EXCLUDE_STATES) %>%
    # Recompute the race-specific outcomes (Stata drops and recreates these)
    mutate(
      nonwhite = ctotalt - cwhitt,
      l_nonwhite = log(nonwhite + 1),
      l_cwhitt = log(cwhitt + 1)
    ) %>%
    arrange(unitid, year)

  # One grouped pass per university: forward-fill the treatment, blank CT
  # after 2017, then build the one-row lag. SC is deliberately left unadjusted
  # for the subject-specific regressions.
  # NOTE(review): lag() takes the previous row within unitid, which equals
  # Stata's L. operator only when a university has no gaps in its years -
  # confirm against the panel.
  panel <- panel %>%
    group_by(unitid) %>%
    fill(!!sym(treat_var), .direction = "down") %>%
    mutate(!!treat_var := if_else(State != "CT" | year <= 2017,
                                  .data[[treat_var]], NA_real_)) %>%
    # Where no lagged value is available, fall back to the current value.
    mutate(!!lagged_name := coalesce(lag(.data[[treat_var]]),
                                     .data[[treat_var]])) %>%
    ungroup()

  panel <- mutate(
    panel,
    unitid_f = factor(unitid),
    year_f = factor(year),
    state_f = factor(State),
    l_ctotalt = log(ctotalt + 1)
  )

  # Only rows with an observed lagged treatment enter the regression sample.
  filter(panel, !is.na(.data[[lagged_name]]))
}

# Folder holding the pre-merged composite treatment workbooks
t6_data_dir <- "data/raw/composite_treatment"

# Column labels, in Table 6 column order
t6_names <- c("Math Test", "Reading Test", "Writing Test", "Binding Test", "ETS Min Scores")

# --- Panel A: Enrollments ---
cat("  Panel A: Enrollments\n")

# Cols 1-4: one workbook per subject, each carrying its own z_score_composite
t6_enroll_files <- c("enrollments_event_data_math.xlsx",
                      "enrollments_event_data_reading.xlsx",
                      "enrollments_event_data_writing.xlsx",
                      "enrollments_event_data_binding.xlsx")
t6_enroll_results <- list()

for (i in seq_along(t6_enroll_files)) {
  cat("  Panel A col (", i, "):", t6_names[i], "...\n")
  enroll_subj <- prepare_t6_enrollment(
    xlsx_path = file.path(t6_data_dir, t6_enroll_files[i]),
    treat_var = "z_score_composite"
  )
  t6_enroll_results[[i]] <- run_twfe(
    enroll_subj, "l_eftotlt", "z_score_composite",
    controls = all_controls
  )
}

# Col 5: ETS Min Scores uses test_index_noncomposite from the main
# enrollments workbook
cat("  Panel A col ( 5 ):", t6_names[5], "...\n")
enroll_ets <- prepare_t6_enrollment(
  xlsx_path = file.path(t6_data_dir, "enrollments_event_data.xlsx"),
  treat_var = "test_index_noncomposite"
)
t6_enroll_results[[5]] <- run_twfe(
  enroll_ets, "l_eftotlt", "test_index_noncomposite",
  controls = all_controls
)

cat("  Enrollment T6 sample (col 1):", t6_enroll_results[[1]]$n, "obs\n")

# --- Panel B: Graduations ---
cat("  Panel B: Graduations\n")

# Cols 1-4: one workbook per subject; the regressor is the lagged treatment
t6_grad_files <- c("graduation_event_data_math.xlsx",
                    "graduation_event_data_reading.xlsx",
                    "graduation_event_data_writing.xlsx",
                    "graduation_event_data_binding.xlsx")
t6_grad_results <- list()

for (i in seq_along(t6_grad_files)) {
  cat("  Panel B col (", i, "):", t6_names[i], "...\n")
  grad_subj <- prepare_t6_graduation(
    xlsx_path = file.path(t6_data_dir, t6_grad_files[i]),
    treat_var = "z_score_composite"
  )
  t6_grad_results[[i]] <- run_twfe(
    grad_subj, "l_ctotalt", "z_score_composite_lag",
    controls = all_controls
  )
}

# Col 5: ETS Min Scores uses test_index_noncomposite from the main
# graduation workbook
cat("  Panel B col ( 5 ):", t6_names[5], "...\n")
grad_ets <- prepare_t6_graduation(
  xlsx_path = file.path(t6_data_dir, "graduation_event_data.xlsx"),
  treat_var = "test_index_noncomposite"
)
t6_grad_results[[5]] <- run_twfe(
  grad_ets, "l_ctotalt", "test_index_noncomposite_lag",
  controls = all_controls
)

cat("  Graduation T6 sample (col 1):", t6_grad_results[[1]]$n, "obs\n")

# Compile Table 6
# Stack Panel A then Panel B so every extracted column has ten entries in
# table order (five enrollment columns followed by five graduation columns).
t6_all_results <- c(t6_enroll_results, t6_grad_results)

# Pull one scalar field out of every regression result. vapply() stops with an
# error if any result lacks the field or holds more than one value, whereas
# sapply() would silently return a list and corrupt the table.
t6_field <- function(field) {
  vapply(t6_all_results, function(res) as.numeric(res[[field]]), numeric(1))
}

table6 <- tibble(
  Column = rep(paste0("(", 1:5, ")"), 2),
  Panel = c(rep("A: Enrollments", 5), rep("B: Graduations", 5)),
  Measure = rep(t6_names, 2),
  TDI = t6_field("coef"),
  SE = t6_field("se"),
  Stars = vapply(t6_field("p"), function(p) as.character(stars(p)),
                 character(1)),
  N = t6_field("n"),
  Mean_Y = t6_field("mean_y"),
  N_Univs = t6_field("n_univs"),
  # Published values, in the same Panel A then Panel B order
  Paper_TDI = c(-0.17, -0.12, -0.11, -0.16, -0.17,
                -0.21, -0.16, -0.18, -0.29, -0.08),
  Paper_SE = c(0.12, 0.07, 0.06, 0.08, 0.08,
               0.08, 0.08, 0.09, 0.09, 0.11),
  Paper_N = c(2880, 2880, 2880, 2880, 2880,
              5748, 5748, 5748, 5748, 5748)
)

write_xlsx(table6, file.path(out_dir, "table_6_alt_tdi.xlsx"))

cat("\nTable 6 Results vs Paper:\n")
for (panel in c("A: Enrollments", "B: Graduations")) {
  cat("  Panel ", panel, ":\n")
  rows <- table6 %>% filter(Panel == panel)
  cat(sprintf("    %-5s %-14s  %7s %6s %6s  | %7s %6s %6s\n",
              "Col", "Measure", "TDI", "SE", "N", "Paper", "SE", "N"))
  cat("    ", strrep("-", 65), "\n")
  # sprintf() is vectorised over the panel's rows, so no 1:nrow() index loop
  # is needed and an empty panel prints nothing.
  cat(sprintf("    %-5s %-14s  %7.3f %6.3f %6d  | %7.2f %6.2f %6d\n",
              rows$Column, rows$Measure,
              rows$TDI, rows$SE, rows$N,
              rows$Paper_TDI, rows$Paper_SE, rows$Paper_N),
      sep = "")
}

# ──────────────────────────────────────────────────────────────────────────────
# TABLE 7: Title II Enrollments and Graduations
# Panel A = Enrollments (log totalenrollment + 1)
# Panel B = Completers (log completerscurrent + 1)
# FE: program + year (NOT university + year)
# Controls: State economic + state policy ONLY (no university controls)
# ──────────────────────────────────────────────────────────────────────────────

cat("\n", strrep("=", 60), "\n")
cat("TABLE 7: Title II Enrollments and Graduations\n")
cat(strrep("=", 60), "\n")

# log(x + 1) with Stata's missing-value behaviour. Stata's log() returns
# missing for any non-positive argument; R's log() returns -Inf at zero and
# NaN (with a warning) below zero. -Inf is not NA, so a later
# filter(!is.na(...)) would keep such rows. Returning NA in both cases makes
# them drop exactly as they do in Stata.
stata_log1p <- function(x) {
  shifted <- x + 1
  out <- rep(NA_real_, length(shifted))
  positive <- !is.na(shifted) & shifted > 0
  out[positive] <- log(shifted[positive])
  out
}

# Read Title II data (pre-merged with treatment vars)
title2 <- read_excel("data/raw/title_ii/title_II_completer_clean.xlsx")

# The Table 7 regressions reference these columns directly, and the completer
# log outcomes are not constructed anywhere in this script, so they must come
# from the workbook. Stop here with a clear message rather than part-way
# through the regressions.
t7_required <- c("programtype", "test_index", "test_index_lag",
                 "log_completerscurrent", "log_whitecompleters",
                 "log_nonwhite_completers")
t7_missing <- setdiff(t7_required, names(title2))
if (length(t7_missing) > 0) {
  stop("Title II workbook is missing column(s) needed for Table 7: ",
       paste(t7_missing, collapse = ", "), call. = FALSE)
}

# Rename year column and set up factors
title2 <- title2 %>%
  rename(year = ipeds_completion_year) %>%
  mutate(
    year_f = factor(year),
    program_f = factor(program_f),
    State = as.character(State),
    # Construct log DVs — match Stata: log(x + 1), missing when x + 1 <= 0
    log_totalenrollment = stata_log1p(totalenrollment),
    log_whiteenrollment = stata_log1p(whiteenrollment),
    log_nonwhite_enrollment = stata_log1p(totalenrollment - whiteenrollment)
  )

# Check/construct shrinking variable: prefer an existing `shrinking` column,
# then `shrinking_state`, and only fall back to the SHRINKING_STATES list when
# neither is present.
if ("shrinking_state" %in% names(title2) && !"shrinking" %in% names(title2)) {
  title2 <- title2 %>% mutate(shrinking = as.integer(shrinking_state))
} else if (!"shrinking" %in% names(title2)) {
  title2 <- title2 %>% mutate(shrinking = as.integer(State %in% SHRINKING_STATES))
}

cat("  Title II data:", nrow(title2), "obs,", n_distinct(title2$program_f), "programs,",
    n_distinct(title2$State), "states\n")
cat("  Program types:", paste(sort(unique(title2$programtype)), collapse=", "), "\n")

# Controls for Title II: state-level only (no university controls)
t7_controls <- c(state_demo, state_kraft)

# Every Table 7 column uses the same specification: the state-level controls
# in t7_controls, program and year fixed effects, and program_f as the unit
# variable. Only the sample, outcome and treatment differ between columns.
fit_t7 <- function(sample_df, outcome, treatment) {
  run_twfe(sample_df, outcome, treatment,
           controls = t7_controls, fe = "program_f + year_f",
           unit_var = "program_f")
}

# --- Panel A: Enrollments ---
# Traditional programs with an observed treatment index
trad_enroll <- filter(title2, programtype == "Traditional" & !is.na(test_index))

cat("  Traditional enrollment sample:", nrow(trad_enroll), "obs\n")

# Col (1): All traditional
cat("  Panel A col (1): All traditional...\n")
t7_e1 <- fit_t7(trad_enroll, "log_totalenrollment", "test_index")

# Col (2): Shrinking states (traditional)
cat("  Panel A col (2): Shrinking...\n")
t7_e2 <- fit_t7(filter(trad_enroll, shrinking == 1),
                "log_totalenrollment", "test_index")

# Col (3): Growing states (traditional)
cat("  Panel A col (3): Growing...\n")
t7_e3 <- fit_t7(filter(trad_enroll, shrinking == 0),
                "log_totalenrollment", "test_index")

# Col (4): White (traditional), rows with a non-missing white outcome
cat("  Panel A col (4): White...\n")
t7_e4 <- fit_t7(filter(trad_enroll, !is.na(log_whiteenrollment)),
                "log_whiteenrollment", "test_index")

# Col (5): Non-White (traditional), rows with a non-missing non-white outcome
cat("  Panel A col (5): Non-White...\n")
t7_e5 <- fit_t7(filter(trad_enroll, !is.na(log_nonwhite_enrollment)),
                "log_nonwhite_enrollment", "test_index")

# Col (6): Alternative programs (all states)
cat("  Panel A col (6): Alternative...\n")
alt_enroll <- filter(title2, programtype == "Alternative" & !is.na(test_index))
t7_e6 <- fit_t7(alt_enroll, "log_totalenrollment", "test_index")

# --- Panel B: Completers ---
# Traditional programs with an observed lagged treatment index
trad_grad <- filter(title2, programtype == "Traditional" & !is.na(test_index_lag))

cat("  Traditional completer sample:", nrow(trad_grad), "obs\n")

# Col (1): All traditional
cat("  Panel B col (1): All traditional...\n")
t7_g1 <- fit_t7(trad_grad, "log_completerscurrent", "test_index_lag")

# Col (2): Shrinking states (traditional)
cat("  Panel B col (2): Shrinking...\n")
t7_g2 <- fit_t7(filter(trad_grad, shrinking == 1),
                "log_completerscurrent", "test_index_lag")

# Col (3): Growing states (traditional)
cat("  Panel B col (3): Growing...\n")
t7_g3 <- fit_t7(filter(trad_grad, shrinking == 0),
                "log_completerscurrent", "test_index_lag")

# Col (4): White (traditional), rows with a non-missing white outcome
cat("  Panel B col (4): White...\n")
t7_g4 <- fit_t7(filter(trad_grad, !is.na(log_whitecompleters)),
                "log_whitecompleters", "test_index_lag")

# Col (5): Non-White (traditional), rows with a non-missing non-white outcome
cat("  Panel B col (5): Non-White...\n")
t7_g5 <- fit_t7(filter(trad_grad, !is.na(log_nonwhite_completers)),
                "log_nonwhite_completers", "test_index_lag")

# Col (6): Alternative programs (all states)
cat("  Panel B col (6): Alternative...\n")
alt_grad <- filter(title2, programtype == "Alternative" & !is.na(test_index_lag))
t7_g6 <- fit_t7(alt_grad, "log_completerscurrent", "test_index_lag")

# Compile Table 7
t7_samples <- c("All", "Shrinking", "Growing", "White", "Non-White", "Alternative")

# All twelve regression results in table order: Panel A cols 1-6, then Panel B
# cols 1-6. Listing them once avoids retyping the same twelve names for every
# extracted field, where one slip would misalign a column.
t7_all_results <- list(t7_e1, t7_e2, t7_e3, t7_e4, t7_e5, t7_e6,
                       t7_g1, t7_g2, t7_g3, t7_g4, t7_g5, t7_g6)

# Pull one scalar field out of every result. [[ ]] matches the name exactly
# (no partial matching as with $), and vapply() stops with an error if any
# result lacks the field or holds more than one value.
t7_field <- function(field) {
  vapply(t7_all_results, function(res) as.numeric(res[[field]]), numeric(1))
}

table7 <- tibble(
  Column = rep(paste0("(", 1:6, ")"), 2),
  Panel = c(rep("A: Enrollments", 6), rep("B: Completers", 6)),
  Sample = rep(t7_samples, 2),
  TDI = t7_field("coef"),
  SE = t7_field("se"),
  Stars = vapply(t7_field("p"), function(p) as.character(stars(p)),
                 character(1)),
  N = t7_field("n"),
  Mean_Y = t7_field("mean_y"),
  # run_twfe() reports the unit count as n_univs; here the unit is the program
  N_Programs = t7_field("n_univs"),
  # Published values, in the same Panel A then Panel B order
  Paper_TDI = c(-0.36, -0.45, -0.02, -0.32, -0.38, -0.09,
                -0.30, -0.32, -0.13, -0.12, -0.36, -0.35),
  Paper_SE = c(0.14, 0.09, 0.28, 0.14, 0.18, 0.35,
               0.11, 0.13, 0.14, 0.11, 0.15, 0.46),
  Paper_N = c(4738, 2562, 2176, 4738, 4715, 1912,
              4742, 2562, 2180, 4742, 4742, 1915)
)

write_xlsx(table7, file.path(out_dir, "table_7_title_II.xlsx"))

cat("\nTable 7 Results vs Paper:\n")
for (panel in c("A: Enrollments", "B: Completers")) {
  cat("  Panel ", panel, ":\n")
  rows <- table7 %>% filter(Panel == panel)
  cat(sprintf("    %-5s %-14s  %7s %6s %6s  | %7s %6s %6s\n",
              "Col", "Sample", "TDI", "SE", "N", "Paper", "SE", "N"))
  cat("    ", strrep("-", 65), "\n")
  # sprintf() is vectorised over the panel's rows, so no 1:nrow() index loop
  # is needed and an empty panel prints nothing.
  cat(sprintf("    %-5s %-14s  %7.3f %6.3f %6d  | %7.2f %6.2f %6d\n",
              rows$Column, rows$Sample,
              rows$TDI, rows$SE, rows$N,
              rows$Paper_TDI, rows$Paper_SE, rows$Paper_N),
      sep = "")
}

# ──────────────────────────────────────────────────────────────────────────────
# Summary: console recap of headline coefficients, sample sizes, the SE method
# in use, and the files written
# ──────────────────────────────────────────────────────────────────────────────

cat(
  paste0("\n ", strrep("=", 60), " \n"),
  "REPLICATION SUMMARY\n",
  paste0(strrep("=", 60), " \n"),
  sep = ""
)

# Headline coefficients next to the paper's values, with the signed gap
cat(
  "\nKey coefficients (R vs Paper):\n",
  sprintf("  Table 4 col(1): %.3f vs -0.16  (diff: %.3f)\n", t4c1$coef, t4c1$coef + 0.16),
  sprintf("  Table 4 col(2): %.3f vs -0.16  (diff: %.3f)\n", t4c2$coef, t4c2$coef + 0.16),
  sprintf("  Table 5 col(1): %.3f vs -0.20  (diff: %.3f)\n", t5c1$coef, t5c1$coef + 0.20),
  sprintf("  Table 5 col(2): %.3f vs -0.22  (diff: %.3f)\n", t5c2$coef, t5c2$coef + 0.22),
  sep = ""
)

# Sample sizes must equal the paper's exactly to count as a match
cat(
  "\nN check:\n",
  sprintf("  Table 4 col(2): %d vs 2,896  %s\n", t4c2$n,
          if (t4c2$n == 2896) "MATCH" else "MISMATCH"),
  sprintf("  Table 5 col(2): %d vs 5,748  %s\n", t5c2$n,
          if (t5c2$n == 5748) "MATCH" else "MISMATCH"),
  sep = ""
)

# Remind the reader which standard errors the run produced
cat("\nSE comparison note:\n")
if (USE_BOOTSTRAP) {
  cat("  Using pairs cluster bootstrap SEs (B=", BOOT_REPS, "), matching Stata.\n")
} else {
  cat("  Using standard clustered SEs (NOT bootstrap).\n")
  cat("  SEs will differ from paper. Set USE_BOOTSTRAP <- TRUE for paper-matching SEs.\n")
}

cat(
  sprintf("\nTable 6 col(1) Panel B: %.3f vs -0.21\n", t6_grad_results[[1]]$coef),
  sprintf("Table 7 col(1) Panel A: %.3f vs -0.36\n", t7_e1$coef),
  sprintf("Table 7 col(1) Panel B: %.3f vs -0.30\n", t7_g1$coef),
  sep = ""
)

# List every output file, one per line, in the same order as before
cat("\nOutput files:\n")
cat(
  paste0(
    "   ",
    file.path(out_dir, c(
      "table_4_enrollments.xlsx",
      "table_5_graduations.xlsx",
      "table_6_alt_tdi.xlsx",
      "table_7_title_II.xlsx",
      "table_A1_robustness.xlsx",
      "table_A2_event_study.xlsx",
      "composite_enrollments_event_study_total.csv",
      "composite_graduations_event_study_total.csv"
    )),
    " \n"
  ),
  sep = ""
)
cat("\n")
